# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Utility function for weight initialization"""

import math
import mindspore as msp
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore.common.initializer import Normal, initializer, HeNormal, Zero


def init_weights(cell: nn.Cell, fc_init_std=0.01, zero_init_final_bn=True):
    """
    Performs ResNet style weight initialization in place.

    Args:
        cell (nn.Cell): root cell; every Conv3d / BatchNorm3d / Dense
            sub-cell found via ``cells_and_names()`` is re-initialized.
        fc_init_std (float): the expected standard deviation for fc layer.
        zero_init_final_bn (bool): if True, zero initialize the final bn for
            every bottleneck.

    Follow the initialization method proposed in:
    {He, Kaiming, et al.
    "Delving deep into rectifiers: Surpassing human-level
    performance on imagenet classification."
    arXiv preprint arXiv:1502.01852 (2015)}
    """
    # Hoist loop-invariant op construction out of the cell loop.
    fill = ops.Fill()
    zeroslike = ops.ZerosLike()
    for _, m in cell.cells_and_names():
        if isinstance(m, nn.Conv3d):
            # He-normal (fan_out) init; the negative_slope argument
            # (sqrt(5)) is ignored when nonlinearity='relu'.
            m.weight.set_data(initializer(
                HeNormal(math.sqrt(5), mode='fan_out', nonlinearity='relu'),
                m.weight.shape, m.weight.dtype))
            if m.bias is not None:
                m.bias.set_data(initializer(
                    Zero(), m.bias.shape, m.bias.dtype))
        elif isinstance(m, nn.BatchNorm3d):
            # Zero-init gamma of a bottleneck's final BN so each residual
            # branch starts as identity (when the flag is set).
            if (hasattr(m, "transform_final_bn")
                    and m.transform_final_bn and zero_init_final_bn):
                batchnorm_weight = 0.0
            else:
                batchnorm_weight = 1.0
            # NOTE(review): assumes nn.BatchNorm3d exposes its inner 2-D
            # batch norm as `bn2d` — confirm against the MindSpore version.
            if m.bn2d.gamma is not None:
                m.bn2d.gamma.set_data(fill(
                    msp.float32, m.bn2d.gamma.shape, batchnorm_weight))
            if m.bn2d.beta is not None:
                m.bn2d.beta.set_data(zeroslike(m.bn2d.beta))
        elif isinstance(m, nn.Dense):
            # Preserve the parameter's own dtype (was hardcoded to float32,
            # inconsistent with the Conv3d branch above).
            m.weight.set_data(initializer(
                Normal(sigma=fc_init_std, mean=0),
                shape=m.weight.shape, dtype=m.weight.dtype))
            if m.bias is not None:
                m.bias.set_data(zeroslike(m.bias))
